{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Ternary Classifier\n",
    "\n",
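    "Create a classifier model on top of a Keras CNN backbone and train it to classify images into three classes. The notebook was originally written for VGG16; the code below currently builds InceptionResNetV2. Originally used with the arXiv images dataset and the classes diagram, sensor, and unsure.\n",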
    "\n",
    "Draws upon code from Fabian Offert: https://github.com/zentralwerkstatt/explain.ipynb/blob/master/train.ipynb"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import warnings\n",
    "warnings.filterwarnings('ignore')\n",
    "\n",
    "import numpy as np\n",
    "\n",
    "import os\n",
    "import sys\n",
    "import glob\n",
    "import random\n",
    "from IPython.display import clear_output, display\n",
    "import PIL.Image\n",
    "\n",
    "# from keras.applications.inception_v3 import InceptionV3, preprocess_input\n",
    "# from keras.applications.vgg16 import VGG16, preprocess_input\n",
    "# from keras.applications.densenet import DenseNet121 as DenseNet, preprocess_input\n",
    "from keras.applications.inception_resnet_v2 import InceptionResNetV2 as ResNet, preprocess_input\n",
    "\n",
    "from keras.models import Model, load_model\n",
    "from keras.layers import Dense, GlobalAveragePooling2D, AveragePooling2D, Dropout, Flatten, Conv2D, Activation\n",
    "from keras.preprocessing import image\n",
    "from keras.preprocessing.image import ImageDataGenerator, img_to_array # load_img\n",
    "from keras.optimizers import SGD, Adam\n",
    "from keras.callbacks import ModelCheckpoint\n",
    "\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let TensorFlow grow its GPU memory allocation on demand instead of\n",
    "# reserving everything up front; this seems to help with some GPU memory issues.\n",
    "\n",
    "import tensorflow as tf\n",
    "from keras import backend as K\n",
    "\n",
    "config = tf.ConfigProto()\n",
    "config.gpu_options.allow_growth = True\n",
    "session = tf.Session(config=config)\n",
    "# Register the configured session with Keras so the model actually runs in it;\n",
    "# without this call Keras lazily creates a separate session of its own.\n",
    "K.set_session(session)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Labelled images, one sub-folder per class inside each directory.\n",
    "train_dir = os.path.expanduser(\"~/data/images/classification/labelled/train/\")\n",
    "val_dir = os.path.expanduser(\"~/data/images/classification/labelled/val/\")\n",
    "\n",
    "classes = 3  # number of output classes\n",
    "batch_size = 32\n",
    "epochs = 2000\n",
    "save_iter = 50  # checkpoint period, handed to ModelCheckpoint(period=...)\n",
    "# save_iter = 25\n",
    "finetune = False\n",
    "# NOTE(review): the model cell currently builds the network with weights=None,\n",
    "# so this value is not used there - confirm which one is intended.\n",
    "weights = 'imagenet'\n",
    "size = 224  # input width and height in pixels (the VGG16 size)\n",
    "# size = 299 # V3\n",
    "\n",
    "# Number of leading layers that setup_to_finetune() freezes.\n",
    "# The V3 value is kept for reference only; it used to be assigned here and was\n",
    "# immediately overwritten by the next line.\n",
    "# freeze = 164 # V3, up to and including mixed5\n",
    "freeze = 20 # VGG16\n",
    "\n",
    "# File name template for checkpoints; ModelCheckpoint fills in {epoch}.\n",
    "# cpath = 'diagram-sensor-unsure_vgg16-{epoch:02d}.hdf5'\n",
    "cpath = 'diagram-sensor-unsure_dnet-{epoch:02d}.hdf5'\n",
    "checkpoints_path = \"checkpoints/\"  # only referenced by the commented-out callbacks cell\n",
    "init_epoch = 25  # passed to fit_generator as initial_epoch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def add_new_last_layer_v3(base_model, classes):\n",
    "    \"\"\"Attach a new classification head to a headless base network.\n",
    "\n",
    "    The base output is global-average-pooled, passed through a 1024-unit ReLU\n",
    "    layer and finished with a softmax layer named 'Predictions'.\n",
    "\n",
    "    Args:\n",
    "        base_model: Keras model whose output feeds the new head.\n",
    "        classes: number of units in the softmax output layer.\n",
    "\n",
    "    Returns:\n",
    "        A Model mapping base_model.input to the new softmax predictions.\n",
    "    \"\"\"\n",
    "    x = base_model.output\n",
    "    x = GlobalAveragePooling2D()(x)\n",
    "    x = Dense(1024, activation='relu')(x)\n",
    "    predictions = Dense(classes, activation='softmax', name='Predictions')(x)\n",
    "    # Keras 2 names these arguments inputs/outputs; the singular Keras 1\n",
    "    # spellings (input/output) are deprecated.\n",
    "    model = Model(inputs=base_model.input, outputs=predictions)\n",
    "    return model\n",
    "\n",
    "def setup_to_finetune(model):\n",
    "    \"\"\"Freeze the leading layers, unfreeze the rest and compile for fine-tuning.\n",
    "\n",
    "    The split point is the module-level `freeze` setting. The model is\n",
    "    compiled in place with SGD (lr=0.0001, momentum=0.9), categorical\n",
    "    cross-entropy loss and the accuracy metric.\n",
    "    \"\"\"\n",
    "    frozen_layers = model.layers[:freeze]\n",
    "    tunable_layers = model.layers[freeze:]\n",
    "    for frozen in frozen_layers:\n",
    "        frozen.trainable = False\n",
    "    for tunable in tunable_layers:\n",
    "        tunable.trainable = True\n",
    "    model.compile(\n",
    "        loss='categorical_crossentropy',\n",
    "        optimizer=SGD(lr=0.0001, momentum=0.9),\n",
    "        metrics=['accuracy'],\n",
    "    )\n",
    "    \n",
    "def setup_to_train(model):\n",
    "    \"\"\"Compile the model in place with a default Adam optimizer.\n",
    "\n",
    "    Uses categorical cross-entropy loss and tracks accuracy.\n",
    "    \"\"\"\n",
    "    model.compile(\n",
    "        loss='categorical_crossentropy',\n",
    "        optimizer=Adam(),\n",
    "        metrics=['accuracy'],\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training images get random augmentation on top of the network preprocessing.\n",
    "train_datagen = ImageDataGenerator(\n",
    "    preprocessing_function=preprocess_input,\n",
    "    rotation_range=30,\n",
    "    width_shift_range=0.2,\n",
    "    height_shift_range=0.2,\n",
    "    shear_range=0.2,\n",
    "    zoom_range=0.2,\n",
    "    horizontal_flip=True\n",
    ")\n",
    "\n",
    "# Validation images are only preprocessed. Random augmentation here would make\n",
    "# the validation metrics noisy from epoch to epoch and would score the model on\n",
    "# distorted images that load_image() never produces at prediction time.\n",
    "val_datagen = ImageDataGenerator(\n",
    "    preprocessing_function=preprocess_input\n",
    ")\n",
    "\n",
    "# Both generators read one sub-folder per class and resize to (size, size).\n",
    "train_generator = train_datagen.flow_from_directory(\n",
    "    train_dir,\n",
    "    target_size=(size, size),\n",
    "    batch_size=batch_size,\n",
    ")\n",
    "\n",
    "validation_generator = val_datagen.flow_from_directory(\n",
    "    val_dir,\n",
    "    target_size=(size, size),\n",
    "    batch_size=batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Take the sample counts from the generators instead of hard-coding them, so\n",
    "# steps_per_epoch / validation_steps stay correct when the dataset changes.\n",
    "# (Previously hard-coded as 7799 training and 1949 validation images.)\n",
    "nb_train_samples = train_generator.samples\n",
    "nb_validation_samples = validation_generator.samples"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Create model\n",
    "\n",
    "Train using stochastic gradient descent (SGD); results with Adam were mixed."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The input shape has to be given explicitly, see\n",
    "# https://stackoverflow.com/questions/49043955/shape-of-input-to-flatten-is-not-fully-defined-got-none-none-64\n",
    "#\n",
    "# Alternative backbones that were tried:\n",
    "# base_model = VGG16(weights=None, input_shape=(size, size, 3), include_top=False)\n",
    "# base_model = InceptionV3(weights=weights, input_shape=(size, size, 3), include_top=False)\n",
    "# base_model = DenseNet(weights=None, input_shape=(size,size,3), include_top=False)\n",
    "input_shape = (size, size, 3)\n",
    "base_model = ResNet(include_top=False, weights=None, input_shape=input_shape)\n",
    "base_model.summary()\n",
    "\n",
    "# freeze layers\n",
    "# for layer in base_model.layers:\n",
    "#     layer.trainable = False\n",
    "\n",
    "# Map every base layer name to its position in base_model.layers.\n",
    "layer_dict = {layer.name: n for n, layer in enumerate(base_model.layers)}\n",
    "model = add_new_last_layer_v3(base_model, classes)\n",
    "# model.summary()\n",
    "\n",
    "# if finetune: setup_to_finetune(model)\n",
    "# else: setup_to_train(model)\n",
    "\n",
    "# Compile with SGD; the Adam variant is kept below for reference.\n",
    "model.compile(\n",
    "    loss='categorical_crossentropy',\n",
    "    optimizer=SGD(lr=0.0001, momentum=0.9),\n",
    "    metrics=['accuracy'],\n",
    ")\n",
    "# model.compile(optimizer=\"adam\", loss='categorical_crossentropy', metrics=['accuracy'])\n",
    "\n",
    "model.summary()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Class weights are used to compensate for the much lower numbers of sensor and unsure images: the rarer a class is in the dataset, the larger the weight it is given."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class_weights = {0: 1,\n",
    "                1: 18,\n",
    "                2: 14}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# callbacks = [\n",
    "#     EpochCheckpoint(checkpoints_path, every=25, startAt=),\n",
    "#     CSVLogger(\"log.csv\", separator=',', append=False)\n",
    "# ]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train classifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Signature for reference:\n",
    "# keras.callbacks.ModelCheckpoint(filepath, monitor='val_loss', verbose=0, save_best_only=False, save_weights_only=False, mode='auto', period=1)\n",
    "checkpoint = ModelCheckpoint(cpath, period=save_iter)\n",
    "\n",
    "# Train with class weighting, starting the epoch counter at init_epoch.\n",
    "history = model.fit_generator(\n",
    "    train_generator,\n",
    "    steps_per_epoch=nb_train_samples // batch_size,\n",
    "    epochs=epochs,\n",
    "    initial_epoch=init_epoch,\n",
    "    validation_data=validation_generator,\n",
    "    validation_steps=nb_validation_samples // batch_size,\n",
    "    class_weight=class_weights,\n",
    "    callbacks=[checkpoint],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# plot training metrics\n",
    "\n",
    "def plot_history(history, metric, label):\n",
    "    \"\"\"Plot the training and validation curves of one metric per epoch.\n",
    "\n",
    "    Args:\n",
    "        history: History object returned by fit_generator.\n",
    "        metric: key of the training curve in history.history; the validation\n",
    "            curve is read from 'val_' + metric.\n",
    "        label: text used in the title and as the y-axis label.\n",
    "    \"\"\"\n",
    "    plt.plot(history.history[metric])\n",
    "    plt.plot(history.history['val_' + metric])\n",
    "    plt.title('model ' + label)\n",
    "    plt.ylabel(label)\n",
    "    plt.xlabel('epoch')\n",
    "    plt.legend(['train', 'test'], loc='upper left')\n",
    "    plt.show()\n",
    "\n",
    "# Depending on the Keras release, accuracy is recorded as 'acc' or 'accuracy'.\n",
    "acc_key = 'acc' if 'acc' in history.history else 'accuracy'\n",
    "\n",
    "# summarize history for accuracy\n",
    "plot_history(history, acc_key, 'accuracy')\n",
    "# summarize history for loss\n",
    "plot_history(history, 'loss', 'loss')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Tests and error checking\n",
    "\n",
    "The code below here loads the model, runs predictions on validation and unseen data, and saves images with the predictions. This is an attempt to understand where the classifier is doing what is expected (or not)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a saved model from the working directory.\n",
    "checkpoint_file = \"diagram-sensor-unsure_vgg16-500.hdf5\"\n",
    "model = load_model(checkpoint_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load a saved model from the checkpoints folder; this rebinds `model`.\n",
    "checkpoint_file = \"checkpoints/ternary_20190911_9748x/diagram-sensor-unsure_vgg16-2000.hdf5\"\n",
    "model = load_model(checkpoint_file)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# was this trying to convert for visualisation?\n",
    "\n",
    "# from tensorflow.python.framework import graph_util\n",
    "# graph_util.convert_variables_to_constants(...)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect the path of every file found under val_dir/<label>/.\n",
    "all_val_images = []\n",
    "\n",
    "# Only sub-folders are class labels: a stray file in val_dir would otherwise\n",
    "# make the inner os.listdir() fail. Sorting gives a reproducible order,\n",
    "# because os.listdir() returns entries in arbitrary order.\n",
    "list_of_labels = sorted(\n",
    "    entry for entry in os.listdir(val_dir)\n",
    "    if os.path.isdir(os.path.join(val_dir, entry))\n",
    ")\n",
    "for label in list_of_labels:\n",
    "    current_label_dir_path = os.path.join(val_dir, label)\n",
    "    list_of_images = sorted(os.listdir(current_label_dir_path))\n",
    "    for im in list_of_images:\n",
    "        current_image_path = os.path.join(current_label_dir_path, im)\n",
    "        all_val_images.append(current_image_path)\n",
    "print(len(all_val_images))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_image(path):\n",
    "    \"\"\"Load one image file and prepare it as a single-image batch for `model`.\n",
    "\n",
    "    Args:\n",
    "        path: path of the image file to read.\n",
    "\n",
    "    Returns:\n",
    "        (img, x): the loaded image resized to the model's input size, and the\n",
    "        corresponding array with a leading batch axis, run through\n",
    "        preprocess_input.\n",
    "    \"\"\"\n",
    "    # input_shape[1:3] is the spatial size the loaded model expects\n",
    "    target_size = model.input_shape[1:3]\n",
    "    img = image.load_img(path, target_size=target_size)\n",
    "    batch = np.expand_dims(image.img_to_array(img), axis=0)\n",
    "    return img, preprocess_input(batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Iterate over the validation images: for each one, get the prediction and the 'ground truth' label, then save the annotated image."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create one output folder per (class, outcome) pair: predictions/<class>_<true|false>\n",
    "classes = [\"diagram\", \"sensor\", \"unsure\"]\n",
    "results = [\"true\", \"false\"]\n",
    "\n",
    "for c in classes:\n",
    "    for r in results:\n",
    "        newdir = os.path.join(\"predictions/\", c + \"_\" + r)\n",
    "        print(newdir)\n",
    "        # makedirs also creates the parent \"predictions\" folder and does\n",
    "        # nothing when the folder already exists.\n",
    "        os.makedirs(newdir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "classes = [\"diagram\", \"sensor\", \"unsure\"]\n",
    "\n",
    "# Iterate over every collected validation file. Looping over the list itself\n",
    "# (rather than a separately maintained count) cannot run past its end or\n",
    "# silently skip files.\n",
    "for r, random_file in enumerate(all_val_images):\n",
    "    print(r)\n",
    "    print(random_file)\n",
    "\n",
    "    img, x = load_image(random_file)\n",
    "\n",
    "    # forward the image through the network\n",
    "    predictions = model.predict(x)\n",
    "    scores = \"diagram: %0.3f | sensor: %0.3f | unsure %0.3f\" % (\n",
    "        predictions[0][0], predictions[0][1], predictions[0][2])\n",
    "    print(scores)\n",
    "\n",
    "    # the 'ground truth' label is the name of the folder holding the image\n",
    "    label = os.path.basename(os.path.dirname(random_file))\n",
    "\n",
    "    # draw the figure, with the label and the prediction as the title\n",
    "    fig = plt.figure()\n",
    "    plt.imshow(img)\n",
    "    plt.axis('off')\n",
    "    plt.title(\"class: \" + label + \"\\n\" + scores, loc='left')\n",
    "\n",
    "    # index of the highest score -> predicted class\n",
    "    ind = np.argmax(predictions)\n",
    "    pred = classes[ind]\n",
    "    # compare prediction and label\n",
    "    result = bool(classes.index(label) == ind)\n",
    "\n",
    "    # images are filed under predictions/<predicted class>_<true|false>\n",
    "    savefolder = os.path.join(\"predictions/\", pred + \"_\" + (\"true\" if result else \"false\"))\n",
    "    savepath = os.path.join(savefolder, \"prediction_e500_\" + os.path.basename(random_file))\n",
    "    print(savepath)\n",
    "    fig.savefig(savepath, dpi=150)\n",
    "    # Release the figure once it is on disk; otherwise one figure per image\n",
    "    # stays open in memory. The saved file is the output of this cell.\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the `predictions` and `label` values currently held in the kernel.\n",
    "classes = [\"diagram\", \"sensor\", \"unsure\"]\n",
    "\n",
    "print(predictions)\n",
    "\n",
    "# index of the highest score and the class name it stands for\n",
    "ind = np.argmax(predictions)\n",
    "print(ind)\n",
    "\n",
    "pred = classes[ind]\n",
    "print(pred)\n",
    "\n",
    "# does the predicted class match the label?\n",
    "result = bool(classes.index(label) == ind)\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print(classes)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Get a folder of unseen images, run the classifier over them, and save the images with the classification"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "unseen_dir = os.path.expanduser(\"~/data/images/random/seq/0-100k/unseen/\")\n",
    "\n",
    "# Build every path from unseen_dir. The file names used to be joined onto\n",
    "# current_label_dir_path, a variable left over from the validation cell, which\n",
    "# produced paths pointing into the validation set instead of the unseen folder.\n",
    "list_of_images = os.listdir(unseen_dir)\n",
    "unseen_images = [os.path.join(unseen_dir, im) for im in list_of_images]\n",
    "nb_unseen_samples = len(unseen_images)\n",
    "print(nb_unseen_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create one output folder per class: predictions_unseen/<class>\n",
    "classes = [\"diagram\", \"sensor\", \"unsure\"]\n",
    "\n",
    "for c in classes:\n",
    "    newdir = os.path.join(\"predictions_unseen/\", c)\n",
    "    print(newdir)\n",
    "    # makedirs also creates the parent \"predictions_unseen\" folder and does\n",
    "    # nothing when the folder already exists.\n",
    "    os.makedirs(newdir, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get unseen images predictions\n",
    "\n",
    "classes = [\"diagram\", \"sensor\", \"unsure\"]\n",
    "\n",
    "# Iterate over the unseen samples. This loop used to index all_val_images, so\n",
    "# it classified validation images again instead of the unseen ones.\n",
    "for r, listed_path in enumerate(unseen_images):\n",
    "    # Anchor the file name to unseen_dir so the image is always read from the\n",
    "    # unseen folder, whatever directory the listed path points at.\n",
    "    random_file = os.path.join(unseen_dir, os.path.basename(listed_path))\n",
    "    print(r)\n",
    "    print(random_file)\n",
    "\n",
    "    img, x = load_image(random_file)\n",
    "\n",
    "    # forward the image through the network\n",
    "    predictions = model.predict(x)\n",
    "    scores = \"diagram: %0.3f | sensor: %0.3f | unsure %0.3f\" % (\n",
    "        predictions[0][0], predictions[0][1], predictions[0][2])\n",
    "    print(scores)\n",
    "\n",
    "    # draw the figure, with the prediction as the title (there is no label here)\n",
    "    fig = plt.figure()\n",
    "    plt.imshow(img)\n",
    "    plt.axis('off')\n",
    "    plt.title(scores, loc='left')\n",
    "\n",
    "    # index of the highest score -> predicted class\n",
    "    ind = np.argmax(predictions)\n",
    "    pred = classes[ind]\n",
    "\n",
    "    # images are filed under predictions_unseen/<predicted class>\n",
    "    savefolder = os.path.join(\"predictions_unseen/\", pred)\n",
    "    savepath = os.path.join(savefolder, \"prediction_\" + os.path.basename(random_file))\n",
    "    print(savepath)\n",
    "    fig.savefig(savepath, dpi=150)\n",
    "    # Release the figure once it is on disk; otherwise one figure per image\n",
    "    # stays open in memory. The saved file is the output of this cell.\n",
    "    plt.close(fig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from keras.utils import plot_model\n",
    "plot_model(model, to_file='model.png')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
